Module 5: Design¶
The aims of this lab are:
- Learn about matplotlib's colormaps, including the awesome viridis.
- Learn how to adjust the design elements of a basic plot in matplotlib.
- Understand the differences between bitmap and vector graphics.
- Learn what SVG is and how to create simple shapes in SVG.
First, import numpy and matplotlib libraries (don't forget the matplotlib inline magic command if you are using Jupyter Notebook).
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline
Colors¶
We discussed colors for categorical and quantitative data. We can further divide the quantitative case into sequential and diverging. "Sequential" means that the underlying values have a natural ordering, so the color only needs to change sequentially and monotonically.
In the "diverging" case, there should be a meaningful anchor point. For instance, correlation values may be positive or negative. Both large positive and large negative correlations are important, and the sign of the correlation carries meaning. Therefore, we would like to stitch two sequential colormaps together, one from zero to +1, the other from zero to -1.
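As a minimal sketch of the diverging case, the snippet below colors a correlation matrix with the coolwarm colormap and pins the color limits to -1 and +1, so that zero falls on the neutral midpoint. The random data, the seed, and the variable names are assumptions made purely for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

# Illustrative data (an assumption): 200 samples of 5 random variables
rng = np.random.default_rng(0)
samples = rng.normal(size=(200, 5))

# Correlation matrix: every value lies in [-1, 1] and the diagonal is 1
corr = np.corrcoef(samples, rowvar=False)

fig, ax = plt.subplots(figsize=(4, 3))
# Symmetric limits keep zero at the neutral center of the diverging colormap
im = ax.imshow(corr, cmap="coolwarm", vmin=-1, vmax=1)
fig.colorbar(im, ax=ax, label="correlation")
```

If the limits were left to autoscale, the neutral color would sit at the middle of the observed range rather than at zero, and the sign of the correlation would no longer be readable from the hue.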
Categorical (qualitative) colormaps¶
To experiment with colormaps, let's create some data first. We will use numpy's random module to create some random data.
numpy¶
numpy is one of the most important packages in Python. As the name suggests (num + py), it handles all kinds of numerical manipulations and is the basis of pretty much all scientific packages. Actually, a pandas "series" is essentially a numpy array and a dataframe is essentially a bunch of numpy arrays grouped together. If you use it wisely, it can easily give you 10x, 100x or even 1000x speed-up!
If you use pandas or other packages, they may do all these numpy optimizations under the hood for you. However, it is still good to know some basic numpy operations. If you want to study numpy more, check out the official tutorial and the "From Python to Numpy" book.
Plotting some trigonometric functions¶
Let's plot a sine and cosine function. By the way, a common trick to plot a function is creating a list of x coordinate values (evenly spaced numbers over an interval) first. numpy has a function called linspace for that ("LINear SPACE"). By default, it creates 50 numbers that fill the interval that you pass.
np.linspace(0, 3)
array([0. , 0.06122449, 0.12244898, 0.18367347, 0.24489796,
0.30612245, 0.36734694, 0.42857143, 0.48979592, 0.55102041,
0.6122449 , 0.67346939, 0.73469388, 0.79591837, 0.85714286,
0.91836735, 0.97959184, 1.04081633, 1.10204082, 1.16326531,
1.2244898 , 1.28571429, 1.34693878, 1.40816327, 1.46938776,
1.53061224, 1.59183673, 1.65306122, 1.71428571, 1.7755102 ,
1.83673469, 1.89795918, 1.95918367, 2.02040816, 2.08163265,
2.14285714, 2.20408163, 2.26530612, 2.32653061, 2.3877551 ,
2.44897959, 2.51020408, 2.57142857, 2.63265306, 2.69387755,
2.75510204, 2.81632653, 2.87755102, 2.93877551, 3. ])
Let's just work with 10 numbers to make it easier to see.
a = np.linspace(0, 3, 10) # 10 numbers instead of 50
a
array([0. , 0.33333333, 0.66666667, 1. , 1.33333333,
1.66666667, 2. , 2.33333333, 2.66666667, 3. ])
A nice thing about numpy is that you can apply many mathematical operations as if you are dealing with a single number.
# add 1 to each element of the array
a_plus_1 = a + 1
print(a_plus_1)
# multiply each element of the array by 3
a_times_3 = a * 3
print(a_times_3)
# raise each element of the array to the power of 2
a_squared = a ** 2
print(a_squared)
# take the square root of each element of the array
a_sqrt = np.sqrt(a)
print(a_sqrt)
[1. 1.33333333 1.66666667 2. 2.33333333 2.66666667 3. 3.33333333 3.66666667 4. ]
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[0. 0.11111111 0.44444444 1. 1.77777778 2.77777778 4. 5.44444444 7.11111111 9. ]
[0. 0.57735027 0.81649658 1. 1.15470054 1.29099445 1.41421356 1.52752523 1.63299316 1.73205081]
These are called "vectorized" operations. Whenever you can, you should use vectorized operations instead of looping over the elements, because they are much faster and more efficient.
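To make the speed claim concrete, here is a rough timing sketch that squares the same array once with a Python loop and once with a vectorized expression. The array size is an arbitrary choice, and the measured times will vary from machine to machine.

```python
import time
import numpy as np

values = np.linspace(0, 3, 200_000)

# Element-by-element loop in pure Python
start = time.perf_counter()
looped = np.array([v ** 2 for v in values])
loop_seconds = time.perf_counter() - start

# The same computation as a single vectorized expression
start = time.perf_counter()
vectorized = values ** 2
vectorized_seconds = time.perf_counter() - start

print(f"loop: {loop_seconds:.4f} s, vectorized: {vectorized_seconds:.4f} s")
```

Both approaches produce the same numbers; only the time spent differs.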
Q: Let's plot some sin and cos functions.
Use numpy's sin and cos functions with matplotlib's plot function to plot them.
x = np.linspace(0, 3*np.pi)
# YOUR SOLUTION HERE
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.legend()
plt.show()
matplotlib picks a pretty good color pair by default! The orange-blue pair is colorblind-safe, and it is like the color pair of every movie.
matplotlib has many qualitative (categorical) colorschemes. https://matplotlib.org/users/colormaps.html

You can access them through the following ways:
plt.cm.Pastel1
or
pastel1 = plt.get_cmap('Pastel1')
pastel1
You can also see the colors in the colormap in RGB (remember what each number means?).
pastel1.colors
((0.984313725490196, 0.7058823529411765, 0.6823529411764706), (0.7019607843137254, 0.803921568627451, 0.8901960784313725), (0.8, 0.9215686274509803, 0.7725490196078432), (0.8705882352941177, 0.796078431372549, 0.8941176470588236), (0.996078431372549, 0.8509803921568627, 0.6509803921568628), (1.0, 1.0, 0.8), (0.8980392156862745, 0.8470588235294118, 0.7411764705882353), (0.9921568627450981, 0.8549019607843137, 0.9254901960784314), (0.9490196078431372, 0.9490196078431372, 0.9490196078431372))
To get the first and second colors, you can use either way:
plt.plot(x, np.sin(x), color=plt.cm.Pastel1(0))
plt.plot(x, np.cos(x), color=pastel1(1))
[<matplotlib.lines.Line2D at 0x7f93acfb7c80>]
Q: pick a qualitative colormap and then draw four different curves with four different colors in the colormap.
Note that the colorschemes are not necessarily colorblindness-safe nor lightness-varied! Think about whether the colormap you chose is a good one or not based on the criteria that we discussed.
# TODO: put your code here
# YOUR SOLUTION HERE
x = np.linspace(0, 4 * np.pi)
cmap = plt.get_cmap('Set2')
plt.plot(x, np.sin(x), label='sin(x)', color=cmap(0), linewidth=3)
plt.plot(x, np.cos(x), label='cos(x)', color=cmap(1), linewidth=3)
plt.plot(x, np.sin(x+np.pi/4), label='sin(x+pi/4)', color=cmap(2), linewidth=3)
plt.plot(x, np.cos(x+np.pi/4), label='cos(x+pi/4)', color=cmap(3), linewidth=3)
# plt.legend()
plt.show()
cmap
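One rough way to check the "lightness-varied" criterion is to estimate the lightness of each color in a qualitative colormap. The sketch below applies the common luminance weights directly to the RGB values of Set2. This is only an approximation (it skips gamma correction), and the variable names are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt

cmap = plt.get_cmap("Set2")
rgb = np.array(cmap.colors)[:, :3]  # one row per color in the colormap

# Rough lightness estimate: luminance weights applied to the raw RGB values
# (an approximation, since no gamma correction is applied)
weights = np.array([0.2126, 0.7152, 0.0722])
lightness = rgb @ weights

for index, value in enumerate(lightness):
    print(f"color {index}: approximate lightness {value:.2f}")
print(f"spread: {lightness.max() - lightness.min():.2f}")
```

A small spread suggests that the colors would be hard to tell apart in grayscale, which is one of the criteria to weigh when choosing a colormap.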
Quantitative colormaps¶
Take a look at the tutorial about image processing in matplotlib: https://matplotlib.org/stable/tutorials/introductory/images.html#sphx-glr-tutorials-introductory-images-py
We can also display an image using quantitative (sequential) colormaps. Download the image of a snake: https://github.com/yy/dviz-course/blob/master/docs/m05-design/sneakySnake.png or use another image of your liking.
Check out the imread() function, which returns a numpy array.
import matplotlib.image as mpimg
import PIL.Image  # import the submodule explicitly so that PIL.Image is available
img = np.array(PIL.Image.open('sneakySnake.png'))
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f93ace0bbf0>
Alternatively, you can use BytesIO to read the image from the URL.
import requests
from io import BytesIO
url = 'https://raw.githubusercontent.com/yy/dviz-course/master/docs/m05-design/sneakySnake.png'
# Fetch the image from the URL
response = requests.get(url)
img = np.array(PIL.Image.open(BytesIO(response.content)))
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f93aceadd00>
How is the image stored?
img
array([[[ 39, 55, 36, 255],
[ 42, 58, 40, 255],
[ 44, 63, 37, 255],
...,
[ 48, 57, 45, 255],
[ 48, 61, 46, 255],
[ 56, 76, 54, 255]],
[[ 48, 65, 45, 255],
[ 47, 67, 42, 255],
[ 77, 103, 64, 255],
...,
[ 47, 54, 46, 255],
[ 46, 56, 47, 255],
[ 44, 58, 45, 255]],
[[ 47, 64, 46, 255],
[ 63, 87, 51, 255],
[107, 141, 81, 255],
...,
[ 51, 61, 50, 255],
[ 44, 54, 45, 255],
[ 45, 57, 47, 255]],
...,
[[ 86, 105, 45, 255],
[ 76, 97, 37, 255],
[ 72, 91, 35, 255],
...,
[ 32, 45, 24, 255],
[ 26, 38, 19, 255],
[ 17, 28, 15, 255]],
[[ 72, 90, 39, 255],
[ 65, 83, 31, 255],
[ 66, 82, 33, 255],
...,
[ 17, 22, 16, 255],
[ 15, 20, 14, 255],
[ 16, 21, 17, 255]],
[[ 53, 70, 25, 255],
[ 55, 68, 22, 255],
[ 62, 71, 26, 255],
...,
[ 13, 15, 14, 255],
[ 17, 19, 18, 255],
[ 19, 21, 20, 255]]], dtype=uint8)
The np.shape() function (or the array's .shape attribute) tells you the dimensions of the array.
np.shape(img)
(219, 329, 4)
This means that img is a three-dimensional array with 219 x 329 x 4 numbers. If you look at the image, you can easily see that 219 and 329 are the dimensions (height and width in terms of the number of pixels) of the image. What is 4?
We can actually create our own small image to investigate. Let's create a 3x3 image.
myimg = np.array([ [[1,0,0,1], [1,1,1,1], [1,1,1,1]],
[[1,1,1,1], [1,1,1,1], [1,0,0,1]],
[[1,1,1,1], [1,1,1,1], [1,0,1,0.5]] ])
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a9410fe0>
Q: Play with the values of the matrix, and explain below what each of the four values in the last dimension represents (this matrix is 3x3x4).
myimg[1,1] = [0,1,0,1]
plt.imshow(myimg)
# The middle pixel becomes green
<matplotlib.image.AxesImage at 0x7f93a8fbbfb0>
myimg[0] = [1, 0, 0, 1]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a8ed5e80>
myimg[0] = [1, 0, 0, 0.2]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a8ed4d10>
myimg[0] = [1, 0, 0, 0.7]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a8ebbf80>
myimg[0,2] = [0,0,1,0.5]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a8ebb110>
myimg[2,0] = [0,0,1,1]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a8465220>
myimg[2,0] = [0.5,0.5,0.5,1]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a9000920>
myimg[2,1] = [1,0,1,1]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a851e060>
myimg[1,0] = [0,1,1,1]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a838fd70>
myimg[0,0] = [1,1,0,1]
plt.imshow(myimg)
<matplotlib.image.AxesImage at 0x7f93a8f7d520>
- Conclusion
- By experimenting with different values in the image array, I was able to understand more clearly what the four values of each pixel represent: red, green, blue, and alpha (opacity). By adjusting these values and observing the changes when displaying the image with plt.imshow, I noticed that the first number controls the red component, the second controls green, the third controls blue, and the fourth controls how opaque the pixel is (0 is fully transparent, 1 is fully opaque).
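The layout described above can be checked on a tiny example. The sketch below builds a hypothetical 2x2 RGBA image and pulls out individual channels by slicing the last dimension; the array name and pixel values are made up for illustration.

```python
import numpy as np

# A 2x2 RGBA image: rows x columns x (red, green, blue, alpha)
tiny = np.array([
    [[1, 0, 0, 1.0], [0, 1, 0, 1.0]],  # opaque red, opaque green
    [[0, 0, 1, 1.0], [1, 1, 1, 0.5]],  # opaque blue, half-transparent white
])

red_channel = tiny[:, :, 0]    # 2x2 array of red intensities
alpha_channel = tiny[:, :, 3]  # 2x2 array of opacities

print(tiny.shape)
print(red_channel)
print(alpha_channel)
```

Slicing with a single index in the last position drops that dimension, so each channel is a plain two-dimensional array.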
Applying other colormaps¶
Let's assume that the first of the four values in each pixel represents some data of interest. You can obtain a height x width matrix by doing img[:,:,0], which means "give me all of the first dimension (:), all of the second dimension (:), but only the first element of the last dimension (0)".
plt.pcolormesh(img[:,:,0], cmap=plt.cm.viridis)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f93a8e8ef90>
Q: Why is it flipped upside down? Take a look at the previous imshow example closely and compare the axes across these two displays. Let's flip the figure upside down to show it properly. This function numpy.flipud() may be handy.
# TODO: put your code here
# YOUR SOLUTION HERE
plt.pcolormesh(np.flipud(img[:,:,0]), cmap=plt.cm.viridis)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f93a81db050>
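The flip happens because imshow puts row 0 at the top by default, while pcolormesh draws row 0 at the bottom of the y-axis. As a sketch of an alternative to np.flipud, the y-axis itself can be inverted. The small array below is a hypothetical stand-in for img[:,:,0].

```python
import numpy as np
import matplotlib.pyplot as plt

grid = np.arange(6).reshape(3, 2)  # small stand-in for img[:, :, 0]

# Option 1: reverse the row order, as np.flipud does
flipped = np.flipud(grid)

# Option 2: leave the data alone and invert the y-axis instead
fig, ax = plt.subplots(figsize=(3, 3))
mesh = ax.pcolormesh(grid, cmap="viridis")
ax.invert_yaxis()
fig.colorbar(mesh, ax=ax)
```

With the second option the tick labels on the y-axis also run from top to bottom, which matches what imshow shows.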
Q: Try another sequential colormap here.
# TODO: put your code here
# YOUR SOLUTION HERE
plt.pcolormesh(np.flipud(img[:,:,0]), cmap=plt.cm.magma)  # magma is sequential; Set2 is a qualitative colormap
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f93a3727470>
Q: Try a diverging colormap, say coolwarm.
# TODO: put your code here
# YOUR SOLUTION HERE
plt.pcolormesh(np.flipud(img[:,:,0]), cmap=plt.cm.coolwarm)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f93a3f36bd0>
Although there are clear choices such as viridis for quantitative data, you can come up with various custom colormaps depending on your application. For instance, take a look at this video about colormaps for Oceanography: https://www.youtube.com/watch?v=XjHzLUnHeM0 There is a colormap designed specifically for the oxygen level, which has three regimes.
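As a hedged illustration of a custom colormap, matplotlib's LinearSegmentedColormap.from_list can interpolate between a list of colors. The colormap name, the two colors, and the data below are arbitrary choices for this sketch, not a recommendation for any particular application.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Hypothetical sequential colormap running from white to navy
white_to_navy = LinearSegmentedColormap.from_list("white_to_navy", ["white", "navy"])

values = np.linspace(0, 1, 100).reshape(10, 10)  # illustrative data
fig, ax = plt.subplots(figsize=(4, 3))
mesh = ax.pcolormesh(values, cmap=white_to_navy)
fig.colorbar(mesh, ax=ax)
```

Passing more than two colors to from_list creates a colormap with several segments, which is one possible starting point for data with multiple regimes.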
Adjusting a plot¶
x = np.linspace(0, 3*np.pi)
plt.xlabel("Some variable")
plt.ylabel("Another variable")
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
[<matplotlib.lines.Line2D at 0x7f93a36125d0>]
You can change the size of the whole figure by using the figsize option. You specify the horizontal and vertical dimensions in inches.
plt.figure(figsize=(4,3))
plt.xlabel("Some variable")
plt.ylabel("Another variable")
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
[<matplotlib.lines.Line2D at 0x7f93a366bf80>]
A very common mistake is making the plot too big compared to the labels and ticks.
plt.figure(figsize=(80, 20))
plt.xlabel("Some variable")
plt.ylabel("Another variable")
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
[<matplotlib.lines.Line2D at 0x7f93a36cf950>]
Once you shrink this plot into a reasonable size, you cannot read the labels anymore! Actually this is one of the most common comments that I provide to my students!
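The reason is that font sizes are given in points (1/72 of an inch), so they do not grow with the figure. The following sketch compares the text height to the figure height for two figure sizes; the dpi value and the helper name are assumptions for illustration.

```python
import matplotlib.pyplot as plt

dpi = 100  # assumed resolution for this sketch
fig = plt.figure(figsize=(4, 3), dpi=dpi)
width_in, height_in = fig.get_size_inches()
width_px, height_px = width_in * dpi, height_in * dpi
print(f"{width_in:.0f} x {height_in:.0f} inches -> {width_px:.0f} x {height_px:.0f} pixels")

font_pt = plt.rcParams["font.size"]  # default label size in points

def text_fraction(figure_height_inches):
    """Height of one line of text relative to the figure height (hypothetical helper)."""
    return (font_pt / 72) / figure_height_inches

print(f"4x3 figure:   text is {text_fraction(3):.1%} of the figure height")
print(f"80x20 figure: text is {text_fraction(20):.1%} of the figure height")
```

When the large figure is scaled down to fit a page, its text shrinks by the same factor, which is why the labels become unreadable.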
You can adjust the range using xlim and ylim.
plt.figure(figsize=(4,3))
plt.xlabel("Some variable")
plt.ylabel("Another variable")
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
plt.xlim((0,4))
plt.ylim((-0.5, 1))
(-0.5, 1.0)
You can adjust the ticks.
plt.figure(figsize=(4,3))
plt.xlabel("Some variable")
plt.ylabel("Another variable")
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
plt.xticks(np.arange(0, 10, 4))
([<matplotlib.axis.XTick at 0x7f93a35a8110>, <matplotlib.axis.XTick at 0x7f93a35a8cb0>, <matplotlib.axis.XTick at 0x7f93a35794c0>], [Text(0, 0, '0'), Text(4, 0, '4'), Text(8, 0, '8')])
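For trigonometric functions, ticks at multiples of π are often easier to read than ticks at integers. Here is a minimal sketch using the Axes methods set_xticks and set_xticklabels; the label strings are an illustrative choice.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 3 * np.pi)
ticks = np.arange(4) * np.pi  # 0, pi, 2*pi, 3*pi
labels = ["0", "π", "2π", "3π"]

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x, np.sin(x))
ax.plot(x, np.cos(x))
ax.set_xticks(ticks)         # where the ticks go
ax.set_xticklabels(labels)   # what each tick says
```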
You can also adjust colors, line widths, and so on.
plt.figure(figsize=(7,4))
plt.xlabel("Some variable")
plt.ylabel("Another variable")
plt.plot(x, np.sin(x), color='red', linewidth=5, label="sine")
plt.plot(x, np.cos(x), label='cosine')
plt.legend(loc='lower left')
<matplotlib.legend.Legend at 0x7f93a3596a80>
For more information, take a look at this excellent tutorial: https://github.com/rougier/matplotlib-tutorial
Q: Now, pick an interesting dataset (e.g. from vega_datasets package) and create a plot. Adjust the size of the figure, labels, colors, and many other aspects of the plot to obtain a nicely designed figure. Explain your rationales for each choice.
# YOUR SOLUTION HERE
from vega_datasets import data
cars = data.cars()
cars.columns
Index(['Name', 'Miles_per_Gallon', 'Cylinders', 'Displacement', 'Horsepower',
'Weight_in_lbs', 'Acceleration', 'Year', 'Origin'],
dtype='object')
# Filter the dataset to include only numerical columns
numerical_columns = cars.select_dtypes(include=['float64', 'int64']).columns
# Loop through each pair of numerical columns and plot scatter plots
for i, col1 in enumerate(numerical_columns):
    for j, col2 in enumerate(numerical_columns):
        if col1 != col2: # Only plot if columns are different
            plt.figure(figsize=(6, 4))
            plt.scatter(cars[col1], cars[col2], alpha=0.5, edgecolor='k')
            plt.title(f'{col1} vs {col2}')
            plt.xlabel(col1)
            plt.ylabel(col2)
            plt.tight_layout() # Make plots neatly arranged without overlapping labels
            plt.show()
Based on the scatter plot analysis, I chose Miles per Gallon, Horsepower, and Cylinders because:
- Miles per Gallon and Horsepower show a negative, nonlinear relationship, where higher horsepower goes with lower fuel efficiency.
- Miles per Gallon and Cylinders reveal that cars with specific MPG ranges have a consistent number of cylinders, with more cylinders corresponding to lower MPG.
- Similar to the second bullet point, Horsepower and Cylinders show that cars with lower horsepower maintain a consistent number of cylinders, typically fewer, while higher horsepower cars tend to have more cylinders.
fig, ax = plt.subplots(figsize=(12, 8))
scatter = ax.scatter(
    cars['Horsepower'], # x-axis data
    cars['Miles_per_Gallon'], # y-axis data
    c=cars['Cylinders'], # Data for color-coding
    cmap=plt.cm.viridis, # Color map
)
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('Cylinders')
ax.set_title('Relationship Between Horsepower and Miles per Gallon')
ax.set_xlabel('Horsepower')
ax.set_ylabel('Miles per Gallon')
plt.show()
SVG¶
First of all, think about various ways to store an image, which can be a beautiful scenery or a geometric shape. How can you efficiently store them in a computer? Consider the pros and cons of different approaches. Which methods would work best for a photograph? Which methods would work best for a blueprint or a histogram?
There are two approaches. One is storing the color of each pixel as shown above. This assumes that each pixel in the image contains some information, which is true in the case of photographs. Obviously, in this case, you cannot zoom in more than the original resolution of the image (unless you are in a movie). Also, if you just want to store some geometric shapes, you will be wasting a lot of space. This is called raster graphics.
Another approach is using vector graphics, where you store the instructions to draw the image rather than the color values of each pixel. For instance, you can store "draw a circle with a radius of 5 at (100,100) with a red line" instead of storing all the red pixels corresponding to the circle. Compared to raster graphics, vector graphics won't lose quality when zooming in.
Since a lot of data visualization tasks are about drawing geometric shapes, vector graphics is a common option. Most libraries allow you to save the figures in vector formats.
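In matplotlib, for instance, savefig can write a figure as SVG. The sketch below saves into an in-memory buffer so that the first characters of the result can be inspected; passing a file name such as "figure.svg" (a hypothetical name) would write a file instead.

```python
from io import BytesIO

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 3 * np.pi)
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x, np.sin(x))

# Write the figure as SVG into an in-memory buffer
buffer = BytesIO()
fig.savefig(buffer, format="svg")
svg_text = buffer.getvalue().decode("utf-8")

print(svg_text[:200])  # the output is plain text made of drawing instructions
```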
On the web, a common standard format is SVG. SVG stands for "Scalable Vector Graphics". Because it's really a list of instructions to draw figures, you can create one even using a basic text editor. What many web-based drawing libraries do is simply write the instructions (SVG) into a webpage, so that a web browser can show the figure. The SVG format can be edited in many vector graphics programs such as Adobe Illustrator and Inkscape. Although we rarely touch the SVG directly when we create data visualizations, I think it's very useful to understand what's going on under the hood. So let's get some intuitive understanding of SVG.
You can put an SVG figure in a page by simply inserting an <svg> tag in an HTML file. It tells the browser to reserve some space for a drawing. For example,
<svg width="200" height="200">
<circle cx="100" cy="100" r="22" fill="yellow" stroke="orange" stroke-width="5"/>
</svg>
This code creates a drawing space of 200x200 pixels and then draws a circle of radius 22 at (100,100). The circle is filled with yellow and stroked with a 5-pixel-wide orange line. That's pretty simple, isn't it? Place this code into an HTML file and open it with your browser. Do you see the circle?
Another cool thing is that, because svg is an HTML tag, you can use CSS to change the styles of your shapes. You can adjust all kinds of styles using CSS:
<head>
<style>
.krypton_sun {
fill: red;
stroke: orange;
stroke-width: 10;
}
</style>
</head>
<body>
<svg width="500" height="500">
<circle cx="200" cy="200" r="50" class="krypton_sun"/>
</svg>
</body>
This code says "draw a circle with a radius 50 at (200, 200), with the style defined for krypton_sun". The style krypton_sun is defined with the <style> tag.
There are other shapes in SVG, such as ellipse, line, polygon (this can be used to create triangles), and path (for curved and other complex lines). You can even place text with advanced formatting inside an svg element.
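To give a feel for these shapes, the sketch below stores a small SVG drawing in a Python string with one ellipse, line, polygon, path, and text element, then parses it to confirm that the markup is well-formed. All coordinates, colors, and the text are arbitrary choices for illustration.

```python
import xml.etree.ElementTree as ET

svg_markup = """
<svg xmlns="http://www.w3.org/2000/svg" width="220" height="160">
  <ellipse cx="60" cy="40" rx="50" ry="25" fill="lightblue" stroke="navy"/>
  <line x1="10" y1="90" x2="110" y2="90" stroke="black" stroke-width="3"/>
  <polygon points="160,10 120,80 200,80" fill="none" stroke="green" stroke-width="3"/>
  <path d="M 10 140 Q 60 100 110 140" fill="none" stroke="red" stroke-width="3"/>
  <text x="125" y="140" font-size="16">Hello SVG</text>
</svg>
"""

# Parsing confirms that the markup is well-formed XML
root = ET.fromstring(svg_markup.strip())
shape_names = [child.tag.split("}")[-1] for child in root]  # drop the namespace prefix
print(shape_names)
```

In a notebook, the same string can be rendered with display(SVG(svg_markup)) from IPython.display, or saved to a .svg file and opened in a browser.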
Exercise:¶
Let's reproduce the symbol for the Deathly Hallows (as shown below) with SVG. It doesn't need to be a perfect duplication (an equilateral triangle, etc.), just make it as visually close as you can. What's the most efficient way of drawing this? Color it in the way you like. Upload this file to Canvas.

from IPython.display import SVG, display
svg_code = '''
<svg width="200" height="190">
<!-- Triangle -->
<polygon points="100,10 20,170 180,170" fill="none" stroke="pink" stroke-width="5"/>
<!-- Circle -->
<circle cx="100" cy="120" r="45" fill="lightgrey" stroke="pink" stroke-width="5"/>
<!-- Line -->
<line x1="100" y1="10" x2="100" y2="170" stroke="pink" stroke-width="5"/>
</svg>
'''
display(SVG(svg_code))
# with open("deathly_hallows.svg", "w") as file:
# file.write(svg_code)
